iT邦幫忙

2025 iThome 鐵人賽

DAY 7
0

前情提要

昨天一樣分步驟講解了 MHA 的運算,怕篇幅太長大腦過載,所以將程式放到今天。

1. 實作

複習一下 self-attention 程式步驟

  1. 定義最基本的 class (init + forward) → 問自己 x 輸入的維度是多少
  2. Q, K, V 都是由 x 經過線性變化來的 → 所以要定義三個 nn.Linear(hidden_size, hidden_size)
    分母的 dk 上面沒有特別描述,就是一個 scaling → hidden_size ** -0.5
  3. 在 forward 準備要做計算
    1. x 做線性轉換 → query, key, value
    2. qk 內積 → torch.matmul, 乘以 scaling
    3. softmax → 得到 attn_weights
    4. 乘以 value → 得到最終 output

開始前一樣先看一下大家名稱怎麼取

  • 輸出: linear_out, linear_o, out_proj, o_proj
  • split 完的維度: head_dim
  • head 的數量: num_attention_heads, num_heads, n_head

這裡一樣分步驟,大致上沒什麼變化,只是多了 split 這個 block,以及最後還要再過一個線性轉換,剩下比較需要思考的就是維度上的操作了
  https://ithelp.ithome.com.tw/upload/images/20250901/20168446TUoxHB2OIf.png
以下使用 linear_q, n_head, head_dim 等名稱

  1. 定義最基本的 class (init + forward) → 問自己 x 輸入的維度是多少
  2. Q, K, V 都是由 x 經過線性變化來的 → 所以要定義三個 nn.Linear(hidden_size, hidden_size)
    多了輸出 → 多宣告一個 nn.Linear
    計算 head_dim = hidden_size/n_head
    scaling 變成 head_dim ** -0.5
  3. 在 forward 準備要做計算
    1. x 做線性轉換 → query, key, value
    2. 做 split 這個 block → 先切割(view) 後 swap (transpose or permute)
    3. qk 內積 → torch.matmul, 乘以 scaling → 得到 attn_scores
    4. softmax → 得到 attn_weights
    5. 乘以 value → 得到 output
    6. 經過輸出線性變換 (先 transpose 後 view 剛好跟第二步相反 → 怎麼來的怎麼回去)
import torch
from torch import nn
import torch.nn.functional as F

# step 1
class MyMultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()    

    def forward(self, x: torch.Tensor):
        '''
            B: batch size
            L: seq len
            D: embedding dimension
            x: (B, L, D) or (B, L, E) 簡寫每個都不太一樣
        '''
        return
import torch
from torch import nn
import torch.nn.functional as F

# step 2
class MyMultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, n_head):
        super().__init__()    
        # 跟之前一樣
        self.linear_q = nn.Linear(hidden_size, hidden_size)
        self.linear_k = nn.Linear(hidden_size, hidden_size)
        self.linear_v = nn.Linear(hidden_size, hidden_size)

        # 多宣告及小地方不一樣
        self.linear_o = nn.Linear(hidden_size, hidden_size)
        self.head_dim = hidden_size // n_head
        self.scaling = self.head_dim ** -0.5

    def forward(self, x: torch.Tensor):
        '''
            B: batch size
            L: seq len
            D: embedding dimension
            x: (B, L, D) or (B, L, E) 簡寫每個都不太一樣
        '''
        return
import torch
from torch import nn
import torch.nn.functional as F

# step 3
class MyMultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, n_head):
        super().__init__()    
        # 跟之前一樣
        self.linear_q = nn.Linear(hidden_size, hidden_size)
        self.linear_k = nn.Linear(hidden_size, hidden_size)
        self.linear_v = nn.Linear(hidden_size, hidden_size)

        # 多宣告及小地方不一樣
        self.linear_o = nn.Linear(hidden_size, hidden_size)
        self.head_dim = hidden_size // n_head
        self.scaling = self.head_dim ** -0.5

    def forward(self, x: torch.Tensor):
        '''
            B: batch size
            L: seq len
            D: embedding dimension
            x: (B, L, D) or (B, L, E) 簡寫每個都不太一樣
        '''
        B, L, D = x.shape
        query = self.linear_q(x)
        key = self.linear_k(x)
        value = self.linear_v(x)

        # (B, L, D) -> (B, L, n_head, head_dim) -> (B, n_head, L, head_dim)
        query = query.view(B, L, -1, self.head_dim).transpose(1, 2)
        key = key.view(B, L, -1, self.head_dim).transpose(1, 2)
        value = value.view(B, L, -1, self.head_dim).transpose(1, 2)

        # (B, n_head, L, head_dim) dot (B, n_head, head_dim, L) = (B, n_head, L, L)
        attn_scores = torch.matmul(query, key.transpose(2, 3)) * self.scaling
        attn_weights = F.softmax(aattn_scores, dim = -1)

        # (B, n_head, L, L) dot (B, n_head, L, head_dim) = (B, n_head, L, head_dim)
        attn_output = torch.matmul(attn_weights, value)
        
        # 可以自行選用其中一個 -> 怎麼來的怎麼回去
        # (B, n_head, L, head_dim) -> (B, L, n_head, head_dim) -> (B, L, D)
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, L, -1)
        # attn_output = attn_output.transpose(1, 2).reshape(B, L, -1)
        attn_output = self.linear_o(attn_output)

        return attn_output

當初我在學的時候總覺得維度操作有點麻煩,但看著圖實際動筆在紙上寫一次,後來漸漸就熟悉了,雖然上面的 code 離實際應用還差一點點,不過已經可以試試看去看人家大公司寫的 code (transformers, Nemo, …),或許你會像我當初一樣忽然讀懂了。 今天就先到這裡囉~~
如果覺得對你有幫助,歡迎動動小手點個讚~


上一篇
Day6: self attention 總結 & MHA 觀念
下一篇
Day8: Padding mask
系列文
實戰派 AI 工程師帶你 0->18
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言